{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f078650e-eda7-429c-a41d-833d75982fb7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import os\n",
    "import PIL\n",
    "import PIL.Image\n",
    "import tensorflow as tf\n",
    "import tensorflow_datasets as tfds\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "from keras.callbacks import TensorBoard\n",
    "from tqdm import tqdm\n",
    "import matplotlib.pyplot as plt\n",
    "from keras import backend as K"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aa29bec5-68b2-4627-aa36-699fb684fcf8",
   "metadata": {},
   "source": [
    "# Dataset Preparation"
   ]
  },
  {
   "cell_type": "raw",
   "id": "89668d9e-f1b3-401a-a82c-e1a1660c9bb4",
   "metadata": {},
   "source": [
    "tfds.list_builders()"
   ]
  },
  {
   "cell_type": "raw",
   "id": "b78d1569-540b-45ff-bf63-f8fbf336baec",
   "metadata": {},
   "source": [
    "(ds_train, ds_test), ds_info = tfds.load(\n",
    "    'kddcup99',\n",
    "    split=['train', 'test'],\n",
    "    shuffle_files=False,\n",
    "    as_supervised=False,\n",
    "    with_info=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "af9354c4-286b-40a6-895e-92beb9e34be8",
   "metadata": {},
   "source": [
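    "Preprocessing: split each record into numeric features and one-hot encoded categorical fields, min-max scale the numeric part, drop duplicate rows, and save the result."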
   ]
  },
  {
   "cell_type": "raw",
   "id": "40f5ac81-4a76-4ab8-812e-87a5230df77a",
   "metadata": {},
   "source": [
    "ds_info"
   ]
  },
  {
   "cell_type": "raw",
   "id": "44348cee-8c8b-4019-aa1e-212bc2e5260d",
   "metadata": {},
   "source": [
    "nos = ['flag', 'label', 'protocol_type', 'service']"
   ]
  },
  {
   "cell_type": "raw",
   "id": "61464fac-9420-43c4-bf5c-f0599c075e18",
   "metadata": {},
   "source": [
    "dataset1 = np.ndarray((len(ds_test) , 38), dtype=np.float32)\n",
    "\n",
    "i = 0\n",
    "for datapoint in tqdm(ds_test.enumerate()):\n",
    "    if (i >= len(ds_test)):\n",
    "        break\n",
    "    j = 0\n",
    "    for item in datapoint[1].items():\n",
    "        if (j >= 38):\n",
    "            break\n",
    "        if item[0] not in nos:\n",
    "            dataset1[i][j] = float(item[1])\n",
    "            j += 1\n",
    "    i += 1\n",
    "print(dataset1[0])"
   ]
  },
  {
   "cell_type": "raw",
   "id": "f49c63ad-884b-4616-a027-4400ad51ec46",
   "metadata": {},
   "source": [
    "dataset2 = np.ndarray((len(ds_test), 89), dtype=np.float32)\n",
    "i = 0\n",
    "for datapoint in tqdm(ds_test.enumerate()):\n",
    "    if (i >= len(ds_test)):\n",
    "        break\n",
    "    j = 0\n",
    "    for item in datapoint[1].items():\n",
    "        if (j >= 4):\n",
    "            break\n",
    "        if item[0] in nos:\n",
    "            if (item[0] == 'protocol_type'):\n",
    "                dataset2[i][item[1]] = 1\n",
    "            elif (item[0] == 'flag'):\n",
    "                # print(item[1])\n",
    "                dataset2[i][item[1] + 3] = 1\n",
    "            elif (item[0] == 'service'):\n",
    "                dataset2[i][item[1] + 14] = 1\n",
    "            elif (item[0] == 'label'):\n",
    "                if (float(item[1]) == 1):\n",
    "                    dataset2[i][79] = 1\n",
    "                elif (float(item[1]) == 7):\n",
    "                    dataset2[i][80] = 1\n",
    "                elif (float(item[1]) == 14):\n",
    "                    dataset2[i][81] = 1\n",
    "                elif (float(item[1]) == 15):\n",
    "                    dataset2[i][82] = 1\n",
    "                elif (float(item[1]) == 16):\n",
    "                    dataset2[i][83] = 1\n",
    "                elif (float(item[1]) == 20):\n",
    "                    dataset2[i][84] = 1\n",
    "                elif (float(item[1]) == 23):\n",
    "                    dataset2[i][85] = 1\n",
    "                elif (float(item[1]) == 25):\n",
    "                    dataset2[i][86] = 1\n",
    "                elif (float(item[1]) == 27):\n",
    "                    dataset2[i][87] = 1\n",
    "                elif (float(item[1]) == 34):\n",
    "                    dataset2[i][88] = 1\n",
    "            j += 1\n",
    "    i += 1\n",
    "print(dataset2)"
   ]
  },
  {
   "cell_type": "raw",
   "id": "7613ebf4-d96f-41a9-abc2-d8c93bbf8e13",
   "metadata": {},
   "source": [
    "dataset1_ = np.array(np.array(((dataset1 - dataset1.min(axis=0)) / (dataset1.max(axis=0) - dataset1.min(axis=0))) * 10000), dtype=np.int32) / 10000\n",
    "print(dataset1_)"
   ]
  },
  {
   "cell_type": "raw",
   "id": "e9ad7a4b-820b-458a-8f3f-cd6837c9ddff",
   "metadata": {},
   "source": [
    "dataset_to_uniqued = np.concatenate((dataset1_, dataset2), axis=1)"
   ]
  },
  {
   "cell_type": "raw",
   "id": "e247c648-9850-4005-bf75-4db9d0ade22c",
   "metadata": {},
   "source": [
    "uniqued_dataset = np.unique(dataset_to_uniqued, axis=0)\n",
    "print(uniqued_dataset.shape)"
   ]
  },
  {
   "cell_type": "raw",
   "id": "dc56bbcf-7f7b-4324-a69d-b06cf369a6f0",
   "metadata": {},
   "source": [
    "np.save(\"./kdd99_dataset_FINAL_test.npy\", uniqued_dataset)"
   ]
  },
  {
   "cell_type": "raw",
   "id": "8393cee0-e736-4676-bd91-77d05aee347c",
   "metadata": {},
   "source": [
    "del uniqued_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b1d1567e-38d1-4593-b2cb-b2e456f98a08",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_dataset = np.load('./kdd99_dataset_FINAL_test.npy', mmap_mode=\"r\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "4aca33bc-aad0-4d07-8929-a20931dfeee7",
   "metadata": {},
   "outputs": [],
   "source": [
    "uniqued_dataset_2 = np.load('./kdd99_dataset_FINAL.npy', mmap_mode=\"r\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "60337d6c-09d7-4c2d-907f-c6a6cfed83ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "uniqued_dataset_2 = np.around(uniqued_dataset_2, 2)\n",
    "uniqued_dataset_2 = np.unique(uniqued_dataset_2, axis=0)\n",
    "test_dataset = np.around(test_dataset, 2)\n",
    "test_dataset = np.unique(test_dataset, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "7213af37-ab7b-438a-a46d-feac1a6aa01f",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = uniqued_dataset_2"
   ]
  },
  {
   "cell_type": "raw",
   "id": "df852acf-aa8c-4033-8920-89cd567ca3b6",
   "metadata": {},
   "source": [
    "dataset = np.unique(dataset, axis=0)\n",
    "dataset.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "138af961-a389-4f19-9872-72ac3c9f6a0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_labels = dataset[:, 117:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "88146245-3167-4be3-8619-937baca8e26e",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_without_label = dataset[:, :117]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7db9fe80-7e45-414b-8af1-62acbfc6c34a",
   "metadata": {},
   "source": [
    "### TRACE"
   ]
  },
  {
   "cell_type": "raw",
   "id": "8e1e7fdc-010d-4501-93ff-b4886b6b3374",
   "metadata": {},
   "source": [
    "dataset_without_label = dataset_without_label[:10]\n",
    "dataset_labels = dataset_labels[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "b5d4f7fa-4b84-44e4-810c-b83632e3d877",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiwAAAGdCAYAAAAxCSikAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAgvklEQVR4nO3df3RT9R3/8VdobYqMxgOVFKQt1SFU6g9IJ7aIbkOzUziew3FHqsziD9jsAZwlRyddd6b2TMr8wWDDVjpBDlNc9z3oxo51M2dTKNYdpad1foX5Y6LpamrX6knQbelo7/cPvmYnpkBvCeRD+nycc88xn96bvJMD5OlNmjgsy7IEAABgsDHJHgAAAOBECBYAAGA8ggUAABiPYAEAAMYjWAAAgPEIFgAAYDyCBQAAGI9gAQAAxktP9gDDMTg4qI8++kjjx4+Xw+FI9jgAAGAYLMvS4cOHNWXKFI0Zc3LnSM6IYPnoo4+Um5ub7DEAAMAIdHZ2aurUqSd1HWdEsIwfP17S0TuclZWV5GkAAMBwhMNh5ebmRp/HT8YZESxfvAyUlZVFsAAAcIZJxNs5eNMtAAAwHsECAACMR7AAAADjESwAAMB4BAsAADAewQIAAIxHsAAAAOMRLAAAwHgECwAAMB7BAgAAjEewAAAA4xEsAADAeAQLAAAwHsECAACMl57sAQAk1rS1zyd7BNs+WL8o2SMAMBxnWAAAgPEIFgAAYDyCBQAAGI9gAQAAxiNYAACA8QgWAABgPIIFAAAYj2ABAADGI1gAAIDxCBYAAGA8ggUAABiPYAEAAMYbUbDU19eroKBAmZmZ8ng8amlpOea+t956qxwOR9w2a9asEQ8NAABGF9vB0tTUpKqqKtXU1Ki9vV3z589XWVmZAoHAkPtv2rRJwWAwunV2dmrChAm64YYbTnp4AAAwOtgOlg0bNmj58uVasWKFCgsLtXHjRuXm5qqhoWHI/V0ul3JycqLb/v379emnn+q222476eEBAMDoYCtY+vv71dbWJq/XG7Pu9XrV2to6rOvYunWrrrnmGuXn59u5aQAAMIql29m5t7dXAwMDcrvdMetut1vd3d0nPD4YDOqFF17Qzp07j7tfJBJRJBKJXg6Hw3bGBAAAKWZEb7p1OBwxly3Lilsbyvbt23XOOedo8eLFx92vrq5OLpcruuXm5o5kTAAAkCJsBUt2drbS0tLizqb09PTEnXX5MsuytG3bNlVUVCgjI+O4+1ZXVysUCkW3zs5OO2MCAIAUYytYMjIy5PF45Pf7Y9b9fr9KS0uPe+yePXv03nvvafny5Se8HafTqaysrJgNAACMXrbewyJJPp9PFRUVKi4uVklJiRobGxUIBFRZWSnp6NmRrq4u7dixI+a4rVu3au7cuSoqKkrM5AAAYNSwHSzl5eXq6+tTbW2tgsGgioqK1NzcHP2tn2AwGPeZLKFQSLt27dKmTZsSMzUAABhVHJZlWcke4kTC4bBcLpdCoRAvDwEnMG3t88kewbYP1i9K9ggAToFEPn/zXUIAAMB4BAsAADAewQIAAIxHsAAAAOMRLAAAwHgECwAAMB7BAgAAjEewAAAA4xEsAADAeAQLAAAwHsECAACMR7AAAADjESwAAMB4BAsAADAewQIAAIxHsAAAAOMRLAAAwHgECwAAMB7BAgAAjEewAAAA4xEsAADAeAQLAAAwHsECAACMR7AAAADjESwAAMB4BAsAADAewQIAAIxHsAAAAOMRLAAAwHgECwAAMB7BAgAAjEewAAAA4xEsAADAeAQLAAAwHsECAACMR7AAAADjESwAAMB4IwqW+vp6FRQUKDMzUx6PRy0tLcfdPxKJqKamRvn5+XI6nbrgggu0bdu2EQ0MAABGn3S7BzQ1Namqqkr19fWaN2+etmzZorKyMh04cEB5eXlDHrNkyRJ9/PHH2rp1q7761a+qp6dHR44cOenhAQDA6OCwLMuyc8DcuXM1Z84c
NTQ0RNcKCwu1ePFi1dXVxe3/hz/8QTfeeKPef/99TZgwYURDhsNhuVwuhUIhZWVljeg6gNFi2trnkz2CbR+sX5TsEQCcAol8/rb1klB/f7/a2trk9Xpj1r1er1pbW4c8Zvfu3SouLtZDDz2k8847TxdeeKHuvvtu/fvf/z7m7UQiEYXD4ZgNAACMXrZeEurt7dXAwIDcbnfMutvtVnd395DHvP/++9q3b58yMzP13HPPqbe3VytXrtQnn3xyzPex1NXV6YEHHrAzGgAASGEjetOtw+GIuWxZVtzaFwYHB+VwOPT000/r8ssv18KFC7VhwwZt3779mGdZqqurFQqFoltnZ+dIxgQAACnC1hmW7OxspaWlxZ1N6enpiTvr8oXJkyfrvPPOk8vliq4VFhbKsiz94x//0PTp0+OOcTqdcjqddkYDAAApzNYZloyMDHk8Hvn9/ph1v9+v0tLSIY+ZN2+ePvroI3322WfRtXfeeUdjxozR1KlTRzAyAAAYbWy/JOTz+fTEE09o27ZtOnjwoNasWaNAIKDKykpJR1/OWbZsWXT/pUuXauLEibrtttt04MAB7d27V/fcc49uv/12jR07NnH3BAAApCzbn8NSXl6uvr4+1dbWKhgMqqioSM3NzcrPz5ckBYNBBQKB6P5f+cpX5Pf7deedd6q4uFgTJ07UkiVL9JOf/CRx9wIAAKQ025/Dkgx8DgswfHwOCwBTJO1zWAAAAJKBYAEAAMYjWAAAgPEIFgAAYDyCBQAAGI9gAQAAxiNYAACA8QgWAABgPIIFAAAYj2ABAADGI1gAAIDxCBYAAGA8ggUAABiPYAEAAMYjWAAAgPEIFgAAYDyCBQAAGI9gAQAAxiNYAACA8QgWAABgPIIFAAAYj2ABAADGI1gAAIDxCBYAAGA8ggUAABiPYAEAAMYjWAAAgPEIFgAAYDyCBQAAGI9gAQAAxiNYAACA8QgWAABgPIIFAAAYj2ABAADGI1gAAIDxCBYAAGA8ggUAABhvRMFSX1+vgoICZWZmyuPxqKWl5Zj7vvzyy3I4HHHb3/72txEPDQAARhfbwdLU1KSqqirV1NSovb1d8+fPV1lZmQKBwHGPe/vttxUMBqPb9OnTRzw0AAAYXWwHy4YNG7R8+XKtWLFChYWF2rhxo3Jzc9XQ0HDc4yZNmqScnJzolpaWNuKhAQDA6GIrWPr7+9XW1iav1xuz7vV61draetxjZ8+ercmTJ2vBggV66aWXjrtvJBJROByO2QAAwOhlK1h6e3s1MDAgt9sds+52u9Xd3T3kMZMnT1ZjY6N27dqlZ599VjNmzNCCBQu0d+/eY95OXV2dXC5XdMvNzbUzJgAASDHpIznI4XDEXLYsK27tCzNmzNCMGTOil0tKStTZ2alHHnlEV1111ZDHVFdXy+fzRS+Hw2GiBQCAUczWGZbs7GylpaXFnU3p6emJO+tyPFdccYXefffdY/7c6XQqKysrZgMAAKOXrWDJyMiQx+OR3++PWff7/SotLR329bS3t2vy5Ml2bhoAAIxitl8S8vl8qqioUHFxsUpKStTY2KhAIKDKykpJR1/O6erq0o4dOyRJGzdu1LRp0zRr1iz19/frqaee0q5du7Rr167E3hMAAJCybAdLeXm5+vr6VFtbq2AwqKKiIjU3Nys/P1+SFAwGYz6Tpb+/X3fffbe6uro0duxYzZo1S88//7wWLlyYuHsBAABSmsOyLCvZQ5xIOByWy+VSKBTi/SzACUxb+3yyR7Dtg/WLkj0CgFMgkc/ffJcQAAAwHsECAACMR7AAAADjESwAAMB4BAsAADAewQIAAIxHsAAAAOMRLAAAwHgECwAAMB7BAgAAjEewAAAA4xEsAADAeAQLAAAwHsECAACMR7AAAADjESwAAMB4BAsAADAewQIAAIxHsAAAAOMRLAAAwHgECwAAMB7BAgAAjEewAAAA4xEsAADAeAQLAAAwHsECAACMR7AAAADjESwAAMB4BAsAADAe
wQIAAIxHsAAAAOMRLAAAwHgECwAAMB7BAgAAjEewAAAA4xEsAADAeAQLAAAw3oiCpb6+XgUFBcrMzJTH41FLS8uwjnvllVeUnp6uyy67bCQ3CwAARinbwdLU1KSqqirV1NSovb1d8+fPV1lZmQKBwHGPC4VCWrZsmRYsWDDiYQEAwOhkO1g2bNig5cuXa8WKFSosLNTGjRuVm5urhoaG4x53xx13aOnSpSopKRnxsAAAYHSyFSz9/f1qa2uT1+uNWfd6vWptbT3mcU8++aT+/ve/67777hvW7UQiEYXD4ZgNAACMXraCpbe3VwMDA3K73THrbrdb3d3dQx7z7rvvau3atXr66aeVnp4+rNupq6uTy+WKbrm5uXbGBAAAKWZEb7p1OBwxly3LiluTpIGBAS1dulQPPPCALrzwwmFff3V1tUKhUHTr7OwcyZgAACBFDO+Ux/+XnZ2ttLS0uLMpPT09cWddJOnw4cPav3+/2tvbtXr1aknS4OCgLMtSenq6XnzxRX3zm9+MO87pdMrpdNoZDQAApDBbZ1gyMjLk8Xjk9/tj1v1+v0pLS+P2z8rK0ptvvqmOjo7oVllZqRkzZqijo0Nz5849uekBAMCoYOsMiyT5fD5VVFSouLhYJSUlamxsVCAQUGVlpaSjL+d0dXVpx44dGjNmjIqKimKOnzRpkjIzM+PWAQAAjsV2sJSXl6uvr0+1tbUKBoMqKipSc3Oz8vPzJUnBYPCEn8kCAABgh8OyLCvZQ5xIOByWy+VSKBRSVlZWsscBjDZt7fPJHsG2D9YvSvYIAE6BRD5/811CAADAeAQLAAAwHsECAACMR7AAAADjESwAAMB4BAsAADAewQIAAIxHsAAAAOMRLAAAwHgECwAAMB7BAgAAjEewAAAA4xEsAADAeAQLAAAwHsECAACMR7AAAADjESwAAMB4BAsAADAewQIAAIxHsAAAAOMRLAAAwHgECwAAMB7BAgAAjEewAAAA4xEsAADAeAQLAAAwHsECAACMR7AAAADjESwAAMB4BAsAADAewQIAAIxHsAAAAOMRLAAAwHgECwAAMB7BAgAAjEewAAAA4xEsAADAeCMKlvr6ehUUFCgzM1Mej0ctLS3H3Hffvn2aN2+eJk6cqLFjx2rmzJn62c9+NuKBAQDA6JNu94CmpiZVVVWpvr5e8+bN05YtW1RWVqYDBw4oLy8vbv9x48Zp9erVuuSSSzRu3Djt27dPd9xxh8aNG6fvfe97CbkTAAAgtTksy7LsHDB37lzNmTNHDQ0N0bXCwkItXrxYdXV1w7qO66+/XuPGjdOvfvWrYe0fDoflcrkUCoWUlZVlZ1xg1Jm29vlkj2DbB+sXJXsEAKdAIp+/bb0k1N/fr7a2Nnm93ph1r9er1tbWYV1He3u7WltbdfXVVx9zn0gkonA4HLMBAIDRy1aw9Pb2amBgQG63O2bd7Xaru7v7uMdOnTpVTqdTxcXFWrVqlVasWHHMfevq6uRyuaJbbm6unTEBAECKGdGbbh0OR8xly7Li1r6spaVF+/fv1+OPP66NGzfqmWeeOea+1dXVCoVC0a2zs3MkYwIAgBRh60232dnZSktLizub0tPTE3fW5csKCgokSRdffLE+/vhj3X///brpppuG3NfpdMrpdNoZDQAApDBbZ1gyMjLk8Xjk9/tj1v1+v0pLS4d9PZZlKRKJ2LlpAAAwitn+tWafz6eKigoVFxerpKREjY2NCgQCqqyslHT05Zyuri7t2LFDkvTYY48pLy9PM2fOlHT0c1keeeQR3XnnnQm8GwAAIJXZDpby8nL19fWptrZWwWBQRUVFam5uVn5+viQpGAwqEAhE9x8cHFR1dbUOHTqk9PR0XXDBBVq/fr3uuOOOxN0LAACQ0mx/Dksy8DkswPDxOSwATJG0z2EBAABIBoIFAAAYj2ABAADGI1gAAIDxCBYAAGA8ggUAABiPYAEAAMYjWAAAgPEIFgAAYDyCBQAAGI9gAQAAxiNY
AACA8QgWAABgPIIFAAAYj2ABAADGI1gAAIDxCBYAAGA8ggUAABiPYAEAAMYjWAAAgPEIFgAAYDyCBQAAGI9gAQAAxiNYAACA8QgWAABgPIIFAAAYLz3ZA2D0mLb2+WSPYNsH6xclewQAgDjDAgAAzgAECwAAMB7BAgAAjEewAAAA4xEsAADAeAQLAAAwHsECAACMR7AAAADjESwAAMB4IwqW+vp6FRQUKDMzUx6PRy0tLcfc99lnn9W1116rc889V1lZWSopKdEf//jHEQ8MAABGH9vB0tTUpKqqKtXU1Ki9vV3z589XWVmZAoHAkPvv3btX1157rZqbm9XW1qZvfOMbuu6669Te3n7SwwMAgNHBYVmWZeeAuXPnas6cOWpoaIiuFRYWavHixaqrqxvWdcyaNUvl5eX68Y9/PKz9w+GwXC6XQqGQsrKy7IwLg/BdQqcHjzMAUyTy+dvWGZb+/n61tbXJ6/XGrHu9XrW2tg7rOgYHB3X48GFNmDDBzk0DAIBRzNa3Nff29mpgYEButztm3e12q7u7e1jX8eijj+rzzz/XkiVLjrlPJBJRJBKJXg6Hw3bGBAAAKWZEb7p1OBwxly3LilsbyjPPPKP7779fTU1NmjRp0jH3q6urk8vlim65ubkjGRMAAKQIW8GSnZ2ttLS0uLMpPT09cWddvqypqUnLly/Xb37zG11zzTXH3be6ulqhUCi6dXZ22hkTAACkGFvBkpGRIY/HI7/fH7Pu9/tVWlp6zOOeeeYZ3Xrrrdq5c6cWLTrxm+ucTqeysrJiNgAAMHrZeg+LJPl8PlVUVKi4uFglJSVqbGxUIBBQZWWlpKNnR7q6urRjxw5JR2Nl2bJl2rRpk6644oro2ZmxY8fK5XIl8K4AAIBUZTtYysvL1dfXp9raWgWDQRUVFam5uVn5+fmSpGAwGPOZLFu2bNGRI0e0atUqrVq1Krp+yy23aPv27Sd/DwAAQMqzHSyStHLlSq1cuXLIn305Ql5++eWR3AQAAEAU3yUEAACMR7AAAADjESwAAMB4BAsAADAewQIAAIxHsAAAAOMRLAAAwHgECwAAMB7BAgAAjEewAAAA4xEsAADAeAQLAAAwHsECAACMR7AAAADjESwAAMB4BAsAADAewQIAAIxHsAAAAOMRLAAAwHgECwAAMB7BAgAAjEewAAAA4xEsAADAeAQLAAAwHsECAACMR7AAAADjESwAAMB4BAsAADAewQIAAIxHsAAAAOMRLAAAwHgECwAAMB7BAgAAjEewAAAA4xEsAADAeAQLAAAwHsECAACMN6Jgqa+vV0FBgTIzM+XxeNTS0nLMfYPBoJYuXaoZM2ZozJgxqqqqGumsAABglLIdLE1NTaqqqlJNTY3a29s1f/58lZWVKRAIDLl/JBLRueeeq5qaGl166aUnPTAAABh9bAfLhg0btHz5cq1YsUKFhYXauHGjcnNz1dDQMOT+06ZN06ZNm7Rs2TK5XK6THhgAAIw+toKlv79fbW1t8nq9Meter1etra0JGyoSiSgcDsdsAABg9LIVLL29vRoYGJDb7Y5Zd7vd6u7uTthQdXV1crlc0S03Nzdh1w0AAM48I3rTrcPhiLlsWVbc2smorq5WKBSKbp2dnQm7bgAAcOZJt7Nzdna20tLS4s6m9PT0xJ11ORlOp1NOpzNh1wcAAM5sts6wZGRkyOPxyO/3x6z7/X6VlpYmdDAAAIAv2DrDIkk+n08VFRUqLi5WSUmJGhsbFQgEVFlZKenoyzldXV3asWNH9JiOjg5J0meffaZ//vOf6ujoUEZGhi666KLE3AsAAJDSbAdLeXm5+vr6VFtbq2AwqKKiIjU3Nys/P1/S0Q+K+/JnssyePTv6321tbdq5c6fy8/P1wQcfnNz0AABgVLAdLJK0cuVKrVy5csifbd++PW7NsqyR3AwAAIAkvksIAACcAQgWAABgPIIFAAAYj2ABAADGI1gAAIDxCBYAAGA8ggUAABiPYAEAAMYj
WAAAgPEIFgAAYDyCBQAAGI9gAQAAxiNYAACA8QgWAABgPIIFAAAYj2ABAADGI1gAAIDxCBYAAGA8ggUAABiPYAEAAMYjWAAAgPEIFgAAYDyCBQAAGI9gAQAAxiNYAACA8QgWAABgPIIFAAAYj2ABAADGI1gAAIDxCBYAAGA8ggUAABiPYAEAAMYjWAAAgPEIFgAAYDyCBQAAGI9gAQAAxiNYAACA8dJHclB9fb0efvhhBYNBzZo1Sxs3btT8+fOPuf+ePXvk8/n01ltvacqUKfrBD36gysrKEQ8NALBv2trnkz2CbR+sX5TsEWAI28HS1NSkqqoq1dfXa968edqyZYvKysp04MAB5eXlxe1/6NAhLVy4UN/97nf11FNP6ZVXXtHKlSt17rnn6tvf/nZC7sRocyb+owMAwMlwWJZl2Tlg7ty5mjNnjhoaGqJrhYWFWrx4serq6uL2v/fee7V7924dPHgwulZZWak33nhDr7766rBuMxwOy+VyKRQKKSsry864J8STP47nTPy/O/5Mnx782Tg9zsTHGf+TyOdvW2dY+vv71dbWprVr18ase71etba2DnnMq6++Kq/XG7P2rW99S1u3btV///tfnXXWWXHHRCIRRSKR6OVQKCTp6B1PtMHIvxJ+nUgdeWv+T7JHgKFOxb9Hp9qZ+O/dmfh38P8+8K1kj2CML/6e2Dw3MiRbwdLb26uBgQG53e6Ydbfbre7u7iGP6e7uHnL/I0eOqLe3V5MnT447pq6uTg888EDcem5urp1xAeCUcW1M9gQwFX824vX19cnlcp3UdYzoTbcOhyPmsmVZcWsn2n+o9S9UV1fL5/NFLw8ODuqTTz7RxIkTj3s7JgmHw8rNzVVnZ2fCX8bC//A4nx48zqcHj/Ppw2N9eoRCIeXl5WnChAknfV22giU7O1tpaWlxZ1N6enrizqJ8IScnZ8j909PTNXHixCGPcTqdcjqdMWvnnHOOnVGNkZWVxV+G04DH+fTgcT49eJxPHx7r02PMmJP/FBVb15CRkSGPxyO/3x+z7vf7VVpaOuQxJSUlcfu/+OKLKi4uHvL9KwAAAF9mO3l8Pp+eeOIJbdu2TQcPHtSaNWsUCASin6tSXV2tZcuWRfevrKzUhx9+KJ/Pp4MHD2rbtm3aunWr7r777sTdCwAAkNJsv4elvLxcfX19qq2tVTAYVFFRkZqbm5Wfny9JCgaDCgQC0f0LCgrU3NysNWvW6LHHHtOUKVP085//POU/g8XpdOq+++6Le2kLicXjfHrwOJ8ePM6nD4/16ZHIx9n257AAAACcbnyXEAAAMB7BAgAAjEewAAAA4xEsAADAeATLKVBfX6+CggJlZmbK4/GopaUl2SOlnLq6On3ta1/T+PHjNWnSJC1evFhvv/12ssdKeXV1dXI4HKqqqkr2KCmnq6tLN998syZOnKizzz5bl112mdra2pI9Vko5cuSIfvSjH6mgoEBjx47V+eefr9raWg0ODiZ7tDPe3r17dd1112nKlClyOBz67W9/G/Nzy7J0//33a8qUKRo7dqy+/vWv66233rJ1GwRLgjU1Namqqko1NTVqb2/X/PnzVVZWFvOr3jh5e/bs0apVq/SXv/xFfr9fR44ckdfr1eeff57s0VLW66+/rsbGRl1yySXJHiXlfPrpp5o3b57OOussvfDCCzpw4IAeffTRM/YTvk3105/+VI8//rg2b96sgwcP6qGHHtLDDz+sX/ziF8ke7Yz3+eef69JLL9XmzZuH/PlDDz2kDRs2aPPmzXr99deVk5Oja6+9VocPHx7+jVhIqMsvv9yqrKyMWZs5c6a1du3aJE00OvT09FiSrD179iR7lJR0+PBha/r06Zbf77euvvpq66677kr2SCnl3nvvta688spkj5HyFi1aZN1+++0xa9dff7118803J2mi1CTJeu6556KXBwcHrZycHGv9+vXRtf/85z+Wy+WyHn/88WFfL2dYEqi/v19tbW3yer0x616v
V62trUmaanQIhUKSlJAv2EK8VatWadGiRbrmmmuSPUpK2r17t4qLi3XDDTdo0qRJmj17tn75y18me6yUc+WVV+pPf/qT3nnnHUnSG2+8oX379mnhwoVJniy1HTp0SN3d3THPjU6nU1dffbWt58YRfVszhtbb26uBgYG4L4J0u91xXwCJxLEsSz6fT1deeaWKioqSPU7K+fWvf622tjbt378/2aOkrPfff18NDQ3y+Xz64Q9/qNdee03f//735XQ6Y77qBCfn3nvvVSgU0syZM5WWlqaBgQE9+OCDuummm5I9Wkr74vlvqOfGDz/8cNjXQ7CcAg6HI+ayZVlxa0ic1atX669//av27duX7FFSTmdnp+666y69+OKLyszMTPY4KWtwcFDFxcVat26dJGn27Nl666231NDQQLAkUFNTk5566int3LlTs2bNUkdHh6qqqjRlyhTdcsstyR4v5Z3scyPBkkDZ2dlKS0uLO5vS09MTV5ZIjDvvvFO7d+/W3r17NXXq1GSPk3La2trU09Mjj8cTXRsYGNDevXu1efNmRSIRpaWlJXHC1DB58mRddNFFMWuFhYXatWtXkiZKTffcc4/Wrl2rG2+8UZJ08cUX68MPP1RdXR3Bcgrl5ORIOnqmZfLkydF1u8+NvIclgTIyMuTxeOT3+2PW/X6/SktLkzRVarIsS6tXr9azzz6rP//5zyooKEj2SClpwYIFevPNN9XR0RHdiouL9Z3vfEcdHR3ESoLMmzcv7tfy33nnneiXyiIx/vWvf2nMmNinvbS0NH6t+RQrKChQTk5OzHNjf3+/9uzZY+u5kTMsCebz+VRRUaHi4mKVlJSosbFRgUBAlZWVyR4tpaxatUo7d+7U7373O40fPz56Vsvlcmns2LFJni51jB8/Pu59QePGjdPEiRN5v1ACrVmzRqWlpVq3bp2WLFmi1157TY2NjWpsbEz2aCnluuuu04MPPqi8vDzNmjVL7e3t2rBhg26//fZkj3bG++yzz/Tee+9FLx86dEgdHR2aMGGC8vLyVFVVpXXr1mn69OmaPn261q1bp7PPPltLly4d/o0k6teY8D+PPfaYlZ+fb2VkZFhz5szhV21PAUlDbk8++WSyR0t5/FrzqfH73//eKioqspxOpzVz5kyrsbEx2SOlnHA4bN11111WXl6elZmZaZ1//vlWTU2NFYlEkj3aGe+ll14a8t/kW265xbKso7/afN9991k5OTmW0+m0rrrqKuvNN9+0dRsOy7KsRBUWAADAqcB7WAAAgPEIFgAAYDyCBQAAGI9gAQAAxiNYAACA8QgWAABgPIIFAAAYj2ABAADGI1gAAIDxCBYAAGA8ggUAABiPYAEAAMb7fwSQTcpTiIfvAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "y_tr_idx = dataset_labels[:, :].argmax(axis=1)\n",
    "class_density = plt.hist(y_tr_idx, bins=10, range=(-0.5, 9.5), density=True)[0]\n",
    "# class_weights = 1 / class_density\n",
    "# class_weights = class_weights / np.max(class_weights)"
   ]
  },
  {
   "cell_type": "raw",
   "id": "eb4ab029-8a73-43e2-abec-fc4755ab471a",
   "metadata": {},
   "source": [
    "y_tr_idx2 = test_dataset[:, :].argmax(axis=1)\n",
    "class_density2 = plt.hist(y_tr_idx2, bins=10, range=(-0.5, 9.5), density=False)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "f2ba67e5-de55-48d6-b699-93310a6d02d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_dataset[:, [3, 4]] = test_dataset[:, [4, 3]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "80917a1a-9673-4f14-b5c0-a9c821a22b47",
   "metadata": {},
   "outputs": [],
   "source": [
    "args = np.argmax(dataset_labels, axis=1)\n",
    "mask = np.logical_or(np.logical_or(args<2, np.logical_and(args<6, args>3)), args==5)\n",
    "# mask = np.logical_or(args<=2, np.logical_and(args<6, args>3))\n",
    "\n",
    "args2 = np.argmax(test_dataset, axis=1)\n",
    "mask2 = np.logical_and(args2<5, args2>=0)\n",
    "# mask2 = np.logical_or(args2<=2, np.logical_and(args2<6, args2>3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4dff29ac-b280-4e06-b39c-e73f5d93f9a9",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9aef79a4-4484-4b5e-847f-b3947ba5aca3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29b39b67-ee35-49c0-aee3-01644aaf4a55",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4173a2d4-6872-4916-a5fd-237953288fc1",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "fdfe5b54-5ae1-407f-9f7f-f359ba826840",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([4, 4, 4, ..., 4, 5, 1], dtype=int64)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "args[mask]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "4f4819dc-439f-48a2-805c-e9bf5a088af6",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_labels = dataset_labels[mask, ...]\n",
    "dataset_without_label = dataset_without_label[mask, ...]\n",
    "test_dataset = test_dataset[mask2, ...]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "8d4cf4cd-aebf-4fb2-a1e6-c6f49a8a6397",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiwAAAGdCAYAAAAxCSikAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAcqUlEQVR4nO3df3DcdZ348Vea0k3xmjC0kraSlnCHUqmiJAfXlJ4/iVM6nWF0pB5KUeDGjMWS5mBo7Y1IRwmidjgtbelBZRxBMyoqjtEjo05bKI4003gM7Yh3VFIhMZN6kxS4S23y+f7Bl9zkkkI3Tbvvbh6Pmf1j33w+u69dAvucz+5+tiTLsiwAABI2pdADAAC8EcECACRPsAAAyRMsAEDyBAsAkDzBAgAkT7AAAMkTLABA8qYWeoDjMTQ0FC+++GLMmDEjSkpKCj0OAHAcsiyLw4cPx9y5c2PKlBM7RnJaBMuLL74YVVVVhR4DABiHgwcPxrnnnntCt3FaBMuMGTMi4tUHXF5eXuBpAIDj0d/fH1VVVcOv4yfitAiW194GKi8vFywAcJqZiI9z+NAtAJA8wQIAJC/vYNm5c2csX7485s6dGyUlJfGjH/3oDffZsWNH1NTURFlZWZx//vmxdevW8cwKAExSeQfLyy+/HBdffHFs2rTpuLY/cOBAXHnllbFkyZLYu3dvfO5zn4vVq1fHD37wg7yHBQAmp7w/dLt06dJYunTpcW+/devWmDdvXtxzzz0REbFgwYLYs2dPfPWrX42PfOQj+d49ADAJnfTPsDz55JNRX18/Yu1DH/pQ7NmzJ/7yl7+Muc/AwED09/ePuAAAk9dJD5bu7u6orKwcsVZZWRlHjx6N3t7eMfdpbm6OioqK4YuTxgHA5HZKviX0f79/nWXZmOuvWbduXfT19Q1fDh48eNJnBADSddJPHDd79uzo7u4esdbT0xNTp06NmTNnjrlPLpeLXC53skcDAE4TJ/0Iy6JFi6KtrW3E2mOPPRa1tbVxxhlnnOy7BwCKQN7B8tJLL0VHR0d0dHRExKtfW+7o6IjOzs6IePXtnJUrVw5v39DQEM8//3w0NTXF/v37Y/v27fHAAw/ELbfcMjGPAAAoenm/JbRnz5543/veN3y9qakpIiKuu+66ePDBB6Orq2s4XiIiqquro7W1NdasWRP33ntvzJ07N77+9a/7SjMAcNxKstc+AZuw/v7+qKioiL6+Pj9+CACniYl8/fZbQgBA8k76t4SAU+u8tT8t9Ah5+8Ndywo9ApA4R1gAgOQJFgAgeYIFAEieYAEAkidYAIDkCRYAIHmCBQBInmABAJInWACA5AkWACB5ggUASJ5gAQCSJ1gAgOQJFgAgeYIFAEieYAEAkidYAIDkCRYAIHmCBQBInmABAJInWACA5AkWACB5ggUASJ5gAQCSJ1gAgOQJFgAgeYIFAEieYAEAkidYAIDkCRYAIHmCBQBInmABAJInWACA5AkWACB5ggUASJ5gAQCSJ1gAgOQJFgAgeYIFAEieYAEAkidYAIDkCRYAIHmCBQBInmABAJInWACA5AkWACB5ggUASJ5gAQCSJ1gAgOQJFgAgeYIFAEieYAEAkidYAIDkCRYAIHmCBQBInmABAJInWACA5AkWACB5ggUASN64gmXz5s1RXV0dZWVlUVNTE7t27Xrd7R966KG4+OKL48wzz4w5c+bEpz71qTh06NC4BgYAJp+8g6WlpSUaGxtj/fr1sXfv3liyZEksXbo0Ojs7x9z+8ccfj5UrV8YNN9wQzzzzTHzve9+Lp556Km688cYTHh4AmBzyDpaNGzfGDTfcEDfeeGMsWLAg7rnnnqiqqootW7aMuf2vf/3rOO+882L16tVRXV0dl19+eXz605+OPXv2nPDwAMDkkFewHDlyJNrb26O+vn7Een19fezevXvMferq6uKPf/xjtLa2RpZl8ac//Sm+//3vx7Jly455PwMDA9Hf
3z/iAgBMXnkFS29vbwwODkZlZeWI9crKyuju7h5zn7q6unjooYdixYoVMW3atJg9e3acddZZ8Y1vfOOY99Pc3BwVFRXDl6qqqnzGBACKzLg+dFtSUjLiepZlo9Zes2/fvli9enV8/vOfj/b29vj5z38eBw4ciIaGhmPe/rp166Kvr2/4cvDgwfGMCQAUian5bDxr1qwoLS0ddTSlp6dn1FGX1zQ3N8fixYvj1ltvjYiId77znfGmN70plixZEl/84hdjzpw5o/bJ5XKRy+XyGQ0AKGJ5HWGZNm1a1NTURFtb24j1tra2qKurG3OfV155JaZMGXk3paWlEfHqkRkAgDeS91tCTU1Ncf/998f27dtj//79sWbNmujs7Bx+i2fdunWxcuXK4e2XL18ejzzySGzZsiWee+65eOKJJ2L16tVx6aWXxty5cyfukQAARSuvt4QiIlasWBGHDh2KDRs2RFdXVyxcuDBaW1tj/vz5ERHR1dU14pwsn/zkJ+Pw4cOxadOm+Kd/+qc466yz4v3vf398+ctfnrhHAQAUtZLsNHhfpr+/PyoqKqKvry/Ky8sLPQ4k7by1Py30CHn7w13HPs0BcPqayNdvvyUEACRPsAAAyRMsAEDyBAsAkDzBAgAkT7AAAMkTLABA8gQLAJA8wQIAJE+wAADJEywAQPIECwCQPMECACRPsAAAyRMsAEDyBAsAkDzBAgAkT7AAAMkTLABA8gQLAJA8wQIAJE+wAADJEywAQPIECwCQPMECACRPsAAAyRMsAEDyBAsAkDzBAgAkT7AAAMkTLABA8gQLAJA8wQIAJE+wAADJEywAQPIECwCQPMECACRPsAAAyRMsAEDyBAsAkDzBAgAkT7AAAMkTLABA8gQLAJA8wQIAJE+wAADJEywAQPIECwCQPMECACRPsAAAyRMsAEDyBAsAkDzBAgAkT7AAAMkTLABA8gQLAJA8wQIAJE+wAADJEywAQPIECwCQPMECACRvXMGyefPmqK6ujrKysqipqYldu3a97vYDAwOxfv36mD9/fuRyufjrv/7r2L59+7gGBgAmn6n57tDS0hKNjY2xefPmWLx4cdx3332xdOnS2LdvX8ybN2/Mfa6++ur405/+FA888ED8zd/8TfT09MTRo0dPeHgAYHIoybIsy2eHyy67LC655JLYsmXL8NqCBQviqquuiubm5lHb//znP4+Pfexj8dxzz8XZZ589riH7+/ujoqIi+vr6ory8fFy3AZPFeWt/WugR8vaHu5YVegTgJJjI1++83hI6cuRItLe3R319/Yj1+vr62L1795j7PProo1FbWxt33313vOUtb4m3vvWtccstt8R///d/H/N+BgYGor+/f8QFAJi88npLqLe3NwYHB6OysnLEemVlZXR3d4+5z3PPPRePP/54lJWVxQ9/+MPo7e2Nz3zmM/HnP//5mJ9jaW5ujjvuuCOf0QCAIjauD92WlJSMuJ5l2ai11wwNDUVJSUk89NBDcemll8aVV14ZGzdujAcffPCYR1nWrVsXfX19w5eDBw+OZ0wAoEjkdYRl1qxZUVpaOupoSk9Pz6ijLq+ZM2dOvOUtb4mKiorhtQULFkSWZfHHP/4xLrjgglH75HK5yOVy+YwGABSxvI6wTJs2LWpqaqKtrW3EeltbW9TV1Y25z+LFi+PFF1+Ml156aXjt2WefjSlTpsS55547jpEBgMkm77eEmpqa4v7774/t27fH/v37Y82aNdHZ2RkNDQ0R8erbOStXrhze/pprromZM2fGpz71qdi3b1/s3Lkzbr311rj++utj+vTpE/dIAICilfd5WFasWBGHDh2KDRs2RFdXVyxcuDBaW1tj/vz5ERHR1dUVnZ2dw9v/1V/9VbS1tcVnP/vZqK2tjZkzZ8bVV18dX/ziFyfuUQAARS3v87AUgvOwwPFzHhYgFQU7DwsAQCEIFgAgeYIFAEieYAEAkidYAIDkCRYAIHmCBQBInmABAJInWACA5AkW
ACB5ggUASJ5gAQCSJ1gAgOQJFgAgeYIFAEieYAEAkidYAIDkCRYAIHmCBQBInmABAJInWACA5AkWACB5ggUASJ5gAQCSJ1gAgOQJFgAgeYIFAEieYAEAkidYAIDkCRYAIHmCBQBInmABAJInWACA5AkWACB5ggUASJ5gAQCSJ1gAgOQJFgAgeYIFAEieYAEAkidYAIDkCRYAIHmCBQBInmABAJInWACA5AkWACB5ggUASJ5gAQCSJ1gAgOQJFgAgeYIFAEieYAEAkidYAIDkCRYAIHmCBQBInmABAJInWACA5AkWACB5ggUASJ5gAQCSJ1gAgOSNK1g2b94c1dXVUVZWFjU1NbFr167j2u+JJ56IqVOnxrve9a7x3C0AMEnlHSwtLS3R2NgY69evj71798aSJUti6dKl0dnZ+br79fX1xcqVK+MDH/jAuIcFACanvINl48aNccMNN8SNN94YCxYsiHvuuSeqqqpiy5Ytr7vfpz/96bjmmmti0aJF4x4WAJic8gqWI0eORHt7e9TX149Yr6+vj927dx9zv29+85vxn//5n3H77bcf1/0MDAxEf3//iAsAMHnlFSy9vb0xODgYlZWVI9YrKyuju7t7zH1+//vfx9q1a+Ohhx6KqVOnHtf9NDc3R0VFxfClqqoqnzEBgCIzrg/dlpSUjLieZdmotYiIwcHBuOaaa+KOO+6It771rcd9++vWrYu+vr7hy8GDB8czJgBQJI7vkMf/N2vWrCgtLR11NKWnp2fUUZeIiMOHD8eePXti7969cdNNN0VExNDQUGRZFlOnTo3HHnss3v/+94/aL5fLRS6Xy2c0AKCI5XWEZdq0aVFTUxNtbW0j1tva2qKurm7U9uXl5fH0009HR0fH8KWhoSHe9ra3RUdHR1x22WUnNj0AMCnkdYQlIqKpqSmuvfbaqK2tjUWLFsW2bduis7MzGhoaIuLVt3NeeOGF+Na3vhVTpkyJhQsXjtj/nHPOibKyslHrAADHknewrFixIg4dOhQbNmyIrq6uWLhwYbS2tsb8+fMjIqKrq+sNz8kCAJCPkizLskIP8Ub6+/ujoqIi+vr6ory8vNDjQNLOW/vTQo+Qtz/ctazQIwAnwUS+fvstIQAgeYIFAEieYAEAkidYAIDkCRYAIHmCBQBInmABAJInWACA5AkWACB5ggUASJ5gAQCSJ1gAgOQJFgAgeYIFAEieYAEAkidYAIDkCRYAIHmCBQBInmABAJInWACA5AkWACB5ggUASJ5gAQCSJ1gAgOQJFgAgeYIFAEieYAEAkidYAIDkCRYAIHmCBQBInmABAJInWACA5AkWACB5ggUASJ5gAQCSJ1gAgOQJFgAgeYIFAEieYAEAkidYAIDkCRYAIHmCBQBInmABAJInWACA5AkWACB5ggUASJ5gAQCSJ1gAgOQJFgAgeYIFAEieYAEAkidYAIDkCRYAIHmCBQBInmABAJInWACA5AkWACB5ggUASJ5gAQCSJ1gAgOQJFgAgeeMKls2bN0d1dXWUlZVFTU1N7Nq165jbPvLII3HFFVfEm9/85igvL49FixbFv/3bv417YABg8sk7WFpaWqKxsTHWr18fe/fujSVLlsTSpUujs7NzzO137twZV1xxRbS2tkZ7e3u8733vi+XLl8fevXtPeHgAYHIoybIsy2eHyy67LC655JLYsmXL8NqCBQviqquuiubm5uO6jYsuuihWrFgRn//8549r+/7+/qioqIi+vr4oLy/PZ1yYdM5b+9NCj5C3P9y1rNAjACfBRL5+53WE5ciRI9He3h719fUj1uvr62P37t3HdRtDQ0Nx+PDhOPvss4+5zcDAQPT394+4AACTV17B0tvbG4ODg1FZWTlivbKyMrq7u4/rNr72ta/Fyy+/HFdfffUxt2lubo6KiorhS1VVVT5jAgBFZlwfui0pKRlxPcuyUWtj+c53vhNf+MIXoqWlJc4555xjbrdu3bro6+sbvhw8eHA8YwIARWJq
PhvPmjUrSktLRx1N6enpGXXU5f9qaWmJG264Ib73ve/FBz/4wdfdNpfLRS6Xy2c0AKCI5XWEZdq0aVFTUxNtbW0j1tva2qKuru6Y+33nO9+JT37yk/Hwww/HsmU+XAcA5CevIywREU1NTXHttddGbW1tLFq0KLZt2xadnZ3R0NAQEa++nfPCCy/Et771rYh4NVZWrlwZ//Iv/xJ/93d/N3x0Zvr06VFRUTGBDwUAKFZ5B8uKFSvi0KFDsWHDhujq6oqFCxdGa2trzJ8/PyIiurq6RpyT5b777oujR4/GqlWrYtWqVcPr1113XTz44IMn/ggAgKKX93lYCsF5WOD4OQ8LkIqCnYcFAKAQBAsAkDzBAgAkT7AAAMkTLABA8gQLAJA8wQIAJE+wAADJEywAQPIECwCQPMECACRPsAAAyRMsAEDyBAsAkDzBAgAkT7AAAMkTLABA8gQLAJA8wQIAJE+wAADJEywAQPIECwCQPMECACRPsAAAyRMsAEDyBAsAkDzBAgAkT7AAAMkTLABA8gQLAJA8wQIAJE+wAADJEywAQPIECwCQPMECACRPsAAAyRMsAEDyBAsAkDzBAgAkT7AAAMkTLABA8gQLAJA8wQIAJE+wAADJEywAQPIECwCQPMECACRPsAAAyRMsAEDyBAsAkDzBAgAkT7AAAMkTLABA8gQLAJA8wQIAJE+wAADJEywAQPIECwCQPMECACRPsAAAyRMsAEDyxhUsmzdvjurq6igrK4uamprYtWvX626/Y8eOqKmpibKysjj//PNj69at4xoWAJicpua7Q0tLSzQ2NsbmzZtj8eLFcd9998XSpUtj3759MW/evFHbHzhwIK688sr4x3/8x/j2t78dTzzxRHzmM5+JN7/5zfGRj3xkQh7EZHPe2p8WeoRx+cNdywo9AgCnqZIsy7J8drjsssvikksuiS1btgyvLViwIK666qpobm4etf1tt90Wjz76aOzfv394raGhIX7729/Gk08+eVz32d/fHxUVFdHX1xfl5eX5jPuGTtcX/9ORYDk1/E2fGv6e4Y1N5Ot3XkdYjhw5Eu3t7bF27doR6/X19bF79+4x93nyySejvr5+xNqHPvSheOCBB+Ivf/lLnHHGGaP2GRgYiIGBgeHrfX19EfHqA59oQwOvTPhtMraT8e+P0fxNnxr+nuGNvfbfSZ7HRsaUV7D09vbG4OBgVFZWjlivrKyM7u7uMffp7u4ec/ujR49Gb29vzJkzZ9Q+zc3Ncccdd4xar6qqymdcElNxT6EngInj7xmO36FDh6KiouKEbiPvz7BERJSUlIy4nmXZqLU32n6s9desW7cumpqahq8PDQ3Fn//855g5c+br3k9K+vv7o6qqKg4ePDjhb2PxvzzPp4bn+dTwPJ86nutTo6+vL+bNmxdnn332Cd9WXsEya9asKC0tHXU0paenZ9RRlNfMnj17zO2nTp0aM2fOHHOfXC4XuVxuxNpZZ52Vz6jJKC8v9x/DKeB5PjU8z6eG5/nU8VyfGlOmnPhZVPK6hWnTpkVNTU20tbWNWG9ra4u6urox91m0aNGo7R977LGora0d8/MrAAD/V97J09TUFPfff39s37499u/fH2vWrInOzs5oaGiIiFffzlm5cuXw9g0NDfH8889HU1NT7N+/P7Zv3x4PPPBA3HLLLRP3KACAopb3Z1hWrFgRhw4dig0bNkRXV1csXLgwWltbY/78+RER0dXVFZ2dncPbV1dXR2tra6xZsybuvffemDt3bnz9618v+nOw5HK5uP3220e9tcXE8jyfGp7nU8PzfOp4rk+NiXye8z4PCwDAqea3hACA5AkWACB5ggUASJ5gAQCSJ1hOgs2bN0d1dXWUlZVFTU1N7Nq1q9AjFZ3m5ub427/925gxY0acc845cdVVV8Xvfve7Qo9V9Jqbm6OkpCQaGxsLPUrReeGFF+ITn/hEzJw5M84888x417veFe3t7YUeq6gcPXo0/vmf/zmqq6tj
+vTpcf7558eGDRtiaGio0KOd9nbu3BnLly+PuXPnRklJSfzoRz8a8c+zLIsvfOELMXfu3Jg+fXq8973vjWeeeSav+xAsE6ylpSUaGxtj/fr1sXfv3liyZEksXbp0xFe9OXE7duyIVatWxa9//etoa2uLo0ePRn19fbz88suFHq1oPfXUU7Ft27Z45zvfWehRis5//dd/xeLFi+OMM86In/3sZ7Fv37742te+dtqe4TtVX/7yl2Pr1q2xadOm2L9/f9x9993xla98Jb7xjW8UerTT3ssvvxwXX3xxbNq0acx/fvfdd8fGjRtj06ZN8dRTT8Xs2bPjiiuuiMOHDx//nWRMqEsvvTRraGgYsXbhhRdma9euLdBEk0NPT08WEdmOHTsKPUpROnz4cHbBBRdkbW1t2Xve857s5ptvLvRIReW2227LLr/88kKPUfSWLVuWXX/99SPWPvzhD2ef+MQnCjRRcYqI7Ic//OHw9aGhoWz27NnZXXfdNbz2P//zP1lFRUW2devW475dR1gm0JEjR6K9vT3q6+tHrNfX18fu3bsLNNXk0NfXFxExIT+wxWirVq2KZcuWxQc/+MFCj1KUHn300aitrY2PfvSjcc4558S73/3u+Nd//ddCj1V0Lr/88vjFL34Rzz77bERE/Pa3v43HH388rrzyygJPVtwOHDgQ3d3dI14bc7lcvOc978nrtXFcv9bM2Hp7e2NwcHDUD0FWVlaO+gFIJk6WZdHU1BSXX355LFy4sNDjFJ3vfve70d7eHnv27Cn0KEXrueeeiy1btkRTU1N87nOfi9/85jexevXqyOVyI37qhBNz2223RV9fX1x44YVRWloag4OD8aUvfSn+4R/+odCjFbXXXv/Gem18/vnnj/t2BMtJUFJSMuJ6lmWj1pg4N910U/z7v/97PP7444UepegcPHgwbr755njssceirKys0OMUraGhoaitrY0777wzIiLe/e53xzPPPBNbtmwRLBOopaUlvv3tb8fDDz8cF110UXR0dERjY2PMnTs3rrvuukKPV/RO9LVRsEygWbNmRWlp6aijKT09PaPKkonx2c9+Nh599NHYuXNnnHvuuYUep+i0t7dHT09P1NTUDK8NDg7Gzp07Y9OmTTEwMBClpaUFnLA4zJkzJ97+9rePWFuwYEH84Ac/KNBExenWW2+NtWvXxsc+9rGIiHjHO94Rzz//fDQ3NwuWk2j27NkR8eqRljlz5gyv5/va6DMsE2jatGlRU1MTbW1tI9bb2tqirq6uQFMVpyzL4qabbopHHnkkfvnLX0Z1dXWhRypKH/jAB+Lpp5+Ojo6O4UttbW18/OMfj46ODrEyQRYvXjzqa/nPPvvs8I/KMjFeeeWVmDJl5MteaWmprzWfZNXV1TF79uwRr41HjhyJHTt25PXa6AjLBGtqaoprr702amtrY9GiRbFt27bo7OyMhoaGQo9WVFatWhUPP/xw/PjHP44ZM2YMH9WqqKiI6dOnF3i64jFjxoxRnwt605veFDNnzvR5oQm0Zs2aqKurizvvvDOuvvrq+M1vfhPbtm2Lbdu2FXq0orJ8+fL40pe+FPPmzYuLLroo9u7dGxs3bozrr7++0KOd9l566aX4j//4j+HrBw4ciI6Ojjj77LNj3rx50djYGHfeeWdccMEFccEFF8Sdd94ZZ555ZlxzzTXHfycT9TUm/te9996bzZ8/P5s2bVp2ySWX+KrtSRARY16++c1vFnq0oudrzSfHT37yk2zhwoVZLpfLLrzwwmzbtm2FHqno9Pf3ZzfffHM2b968rKysLDv//POz9evXZwMDA4Ue7bT3q1/9asz/J1933XVZlr361ebbb789mz17dpbL5bK///u/z55++um87qMky7JsogoLAOBk8BkWACB5ggUASJ5gAQCSJ1gAgOQJFgAgeYIFAEieYAEAkidYAIDkCRYAIHmCBQBInmABAJInWACA5P0/rvXgvAllcjMAAAAASUVORK5CYII=
\n",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "y_tr_idx2 = dataset_labels[:, :].argmax(axis=1)\n",
    "class_density = plt.hist(y_tr_idx2, bins=10, range=(-0.5, 9.5), density=True)[0]\n",
    "# class_weights = 1 / class_density\n",
    "# class_weights = class_weights / np.max(class_weights)"
   ]
  },
  {
   "cell_type": "raw",
   "id": "7cda832a-9109-4171-8cc4-ce7e2a0a0857",
   "metadata": {},
   "source": [
    "y_tr_idx2 = test_dataset[:, :].argmax(axis=1)\n",
    "class_density2 = plt.hist(y_tr_idx2, bins=10, range=(-0.5, 9.5), density=False)[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "77408305-d86d-4e43-9ea6-34a133cb3739",
   "metadata": {},
   "source": [
    "# Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "f28c5ee6-1cb0-4776-813d-19f5e6bfd13b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Building blocks shared by the stacked autoencoders and the classifier head.\n",
    "# The four hidden layers are instantiated in the original order (H1 first, H4 last).\n",
    "input_layer = keras.Input(shape=(117,))\n",
    "H1, H2, H3, H4 = [layers.Dense(units=32, activation='sigmoid') for _ in range(4)]\n",
    "# Classifier head: a 5-unit sigmoid layer followed by a 5-way softmax.\n",
    "supervised_layer = layers.Dense(units=5, activation='sigmoid')\n",
    "output_layer = layers.Dense(units=5, activation='softmax')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "09adcd12-291c-42cd-b523-ee00494c883c",
   "metadata": {},
   "source": [
    "### AutoEncoder 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "a37442de-05a5-434e-b3f5-c0a1ad6cb413",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"model\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " input_1 (InputLayer)        [(None, 117)]             0         \n",
      "                                                                 \n",
      " dense (Dense)               (None, 32)                3776      \n",
      "                                                                 \n",
      " dense_6 (Dense)             (None, 117)               3861      \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 7,637\n",
      "Trainable params: 7,637\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "encoded1 = H1(input_layer)\n",
    "decoded1 = layers.Dense(117,)(encoded1)\n",
    "auto_encoder1 = keras.Model(input_layer, decoded1)\n",
    "auto_encoder1.compile(loss=\"mse\", metrics=[\"mse\", 'mae'], weighted_metrics=[\"mse\", 'mae'])\n",
    "auto_encoder1.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "5b0b20ec-67f2-457a-8aac-9a4709d965b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "only_encoded1 = keras.Model(input_layer, encoded1)\n",
    "only_encoded1.compile(loss=\"mse\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "55029793-57a4-490b-aa05-f29b925cc75c",
   "metadata": {},
   "source": [
    "### AutoEncoder 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "5b28ccad-cabc-49ed-8847-22c9afc3f784",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Second autoencoder: takes the H1 code (encoded1) as its model input, encodes it with H2\n",
    "# and reconstructs 32 values with a Dense layer without activation.\n",
    "encoded2 = H2(encoded1)\n",
    "decoded2 = layers.Dense(units=32)(encoded2)\n",
    "auto_encoder2 = keras.Model(inputs=encoded1, outputs=decoded2)\n",
    "auto_encoder2.compile(\n",
    "    loss=\"mse\",\n",
    "    metrics=[\"mse\", 'mae'],\n",
    "    weighted_metrics=[\"mse\", 'mae'],\n",
    ")\n",
    "# auto_encoder2.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "3ab6c825-b86a-4970-9e2e-29bc7c0defe0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encoder-only model: maps the raw input through H1 and H2 to the encoded2 tensor.\n",
    "only_encoded2 = keras.Model(inputs=input_layer, outputs=encoded2)\n",
    "only_encoded2.compile(\n",
    "    loss=\"mse\",\n",
    "    metrics=[\"mae\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "de573b6f-5533-4642-9f6b-6579c4f086e0",
   "metadata": {},
   "source": [
    "### AutoEncoder 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "7d314c4f-989d-413a-a558-7367c5140c96",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Third autoencoder: takes the H2 code (encoded2) as its model input, encodes it with H3\n",
    "# and reconstructs 32 values with a Dense layer without activation.\n",
    "encoded3 = H3(encoded2)\n",
    "decoded3 = layers.Dense(units=32)(encoded3)\n",
    "auto_encoder3 = keras.Model(inputs=encoded2, outputs=decoded3)\n",
    "auto_encoder3.compile(\n",
    "    loss=\"mse\",\n",
    "    metrics=[\"mse\", 'mae'],\n",
    "    weighted_metrics=[\"mse\", 'mae'],\n",
    ")\n",
    "# auto_encoder3.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "cec37065-7412-482d-91a8-bb5b46e12117",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encoder-only model: maps the raw input through H1, H2 and H3 to the encoded3 tensor.\n",
    "only_encoded3 = keras.Model(inputs=input_layer, outputs=encoded3)\n",
    "only_encoded3.compile(\n",
    "    loss=\"mse\",\n",
    "    metrics=[\"mae\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e3f57b10-9ce4-4042-bc3f-1ebb6ef6acf0",
   "metadata": {},
   "source": [
    "### AutoEncoder 4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "4fceb81b-8f24-4a92-8b0b-77b89e1c7f4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fourth autoencoder: takes the H3 code (encoded3) as its model input, encodes it with H4\n",
    "# and reconstructs 32 values with a Dense layer without activation.\n",
    "encoded4 = H4(encoded3)\n",
    "decoded4 = layers.Dense(units=32)(encoded4)\n",
    "auto_encoder4 = keras.Model(inputs=encoded3, outputs=decoded4)\n",
    "auto_encoder4.compile(\n",
    "    loss=\"mse\",\n",
    "    metrics=[\"mse\", 'mae'],\n",
    "    weighted_metrics=[\"mse\", 'mae'],\n",
    ")\n",
    "# auto_encoder4.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "64d647f1-be36-4d78-8a15-bb60887c3165",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encoder-only model: maps the raw input through H1..H4 to the encoded4 tensor.\n",
    "only_encoded4 = keras.Model(inputs=input_layer, outputs=encoded4)\n",
    "only_encoded4.compile(\n",
    "    loss=\"mse\",\n",
    "    metrics=[\"mae\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3be5c7d8-db6c-4513-b51b-41695c3defb5",
   "metadata": {},
   "source": [
    "### Supervised Layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "262a8dc4-a6ae-4024-8f06-49c299b9655c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Classifier head on its own: the model input is the H4 code (encoded4), which is passed\n",
    "# through supervised_layer and then the softmax output_layer.\n",
    "temp_ = supervised_layer(encoded4)\n",
    "out__ = output_layer(temp_)\n",
    "supervised_part = keras.Model(inputs=encoded4, outputs=out__)\n",
    "supervised_part.compile(\n",
    "    loss='categorical_crossentropy',\n",
    "    metrics=['accuracy', 'categorical_crossentropy'],\n",
    "    weighted_metrics=['accuracy', 'categorical_crossentropy'],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ffecb4ec-17ce-4261-9e67-5a7bff54f3cc",
   "metadata": {},
   "source": [
    "### Full Assembly of the Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "82049bb2-f67f-427d-b968-ef9ebc47f8d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# End-to-end classifier: raw input -> H1..H4 -> supervised_layer -> softmax output_layer.\n",
    "temp = supervised_layer(encoded4)\n",
    "out_ = output_layer(temp)\n",
    "final_supervised = keras.Model(inputs=input_layer, outputs=out_)\n",
    "final_supervised.compile(\n",
    "    # The optimizer and its learning rate are stated explicitly instead of relying on the\n",
    "    # implicit compile() default and patching the rate afterwards through the Keras backend.\n",
    "    optimizer=keras.optimizers.RMSprop(learning_rate=0.001),\n",
    "    loss='categorical_crossentropy',\n",
    "    metrics=[\n",
    "        'accuracy',\n",
    "        'categorical_crossentropy',\n",
    "        # Confusion-matrix counts with the metrics' default arguments\n",
    "        # (the removed thresholds=None, name=None, dtype=None were those defaults).\n",
    "        tf.keras.metrics.TruePositives(),\n",
    "        tf.keras.metrics.FalseNegatives(),\n",
    "        tf.keras.metrics.TrueNegatives(),\n",
    "        tf.keras.metrics.FalsePositives(),\n",
    "    ],\n",
    "    weighted_metrics=['accuracy', 'categorical_crossentropy'],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dbc18db9-22c4-4482-9c6f-8342be2fb06a",
   "metadata": {},
   "source": [
    "# Training"
   ]
  },
  {
   "cell_type": "raw",
   "id": "ef69193b-3f31-429e-83a6-217e5f6e53fd",
   "metadata": {},
   "source": [
    "auto_encoder1 = tf.keras.models.load_model('./auto_encoder_1_')"
   ]
  },
  {
   "cell_type": "raw",
   "id": "45021773-ec17-4677-9186-df6d7bfb2994",
   "metadata": {},
   "source": [
    "my_sample_weights = class_weights[y_tr_idx[np.arange(dataset_without_label.shape[0])]]"
   ]
  },
  {
   "cell_type": "raw",
   "id": "f6aca775-0682-466c-a98c-c9f7ba7044fb",
   "metadata": {},
   "source": [
    "auto_encoder1.fit(dataset_without_label, dataset_without_label, epochs=25, batch_size=100,\n",
    "            shuffle=False, validation_data=None, sample_weight=my_sample_weights,\n",
    "            callbacks=[TensorBoard(log_dir='/tmp/autoencoder')])"
   ]
  },
  {
   "cell_type": "raw",
   "id": "25c28d38-6bf2-4963-954a-8e5e35abbd2f",
   "metadata": {},
   "source": [
    "auto_encoder1.save(\"./auto_encoder_1_\")"
   ]
  },
  {
   "cell_type": "raw",
   "id": "1241f2f6-433c-48f2-981e-a284ceb8f355",
   "metadata": {},
   "source": [
    "auto_encoder2 = tf.keras.models.load_model('./auto_encoder_2_')"
   ]
  },
  {
   "cell_type": "raw",
   "id": "0727eb21-a1e8-4461-9fe9-f08cce6a8309",
   "metadata": {},
   "source": [
    "auto_encoder2.fit(only_encoded1(dataset_without_label), only_encoded1(dataset_without_label), epochs=3, batch_size=100,\n",
    "            shuffle=False, validation_data=None, sample_weight=my_sample_weights,\n",
    "            callbacks=[TensorBoard(log_dir='/tmp/autoencoder2')])"
   ]
  },
  {
   "cell_type": "raw",
   "id": "749361d2-49dc-43bd-8a68-4d17ec3c6bb8",
   "metadata": {},
   "source": [
    "auto_encoder2.save(\"./auto_encoder_2_\")"
   ]
  },
  {
   "cell_type": "raw",
   "id": "8f53520c-9131-43dc-9655-509c2bdb224f",
   "metadata": {},
   "source": [
    "auto_encoder3 = tf.keras.models.load_model('./auto_encoder_3_')"
   ]
  },
  {
   "cell_type": "raw",
   "id": "f2f0f0aa-6fca-42c0-98d7-62ad6e421fac",
   "metadata": {},
   "source": [
    "auto_encoder3.fit(only_encoded2(dataset_without_label), only_encoded2(dataset_without_label), epochs=3, batch_size=100,\n",
    "            shuffle=False, validation_data=None, sample_weight=my_sample_weights,\n",
    "            callbacks=[TensorBoard(log_dir='/tmp/autoencoder3')])"
   ]
  },
  {
   "cell_type": "raw",
   "id": "b7648fa3-cc96-4d17-8a0d-78c3c1796987",
   "metadata": {},
   "source": [
    "auto_encoder3.save(\"./auto_encoder_3_\")"
   ]
  },
  {
   "cell_type": "raw",
   "id": "7a4a47ea-8d2b-420c-a582-42d4c432233b",
   "metadata": {},
   "source": [
    "auto_encoder4 = tf.keras.models.load_model('./auto_encoder_4_')"
   ]
  },
  {
   "cell_type": "raw",
   "id": "6477f63c-90bb-4b1b-9429-5045c640526c",
   "metadata": {},
   "source": [
    "auto_encoder4.fit(only_encoded3(dataset_without_label), only_encoded3(dataset_without_label), epochs=3, batch_size=100,\n",
    "            shuffle=False, validation_data=None, sample_weight=my_sample_weights,\n",
    "            callbacks=[TensorBoard(log_dir='/tmp/autoencoder4')])"
   ]
  },
  {
   "cell_type": "raw",
   "id": "c82aa4e1-c2c0-4ceb-a9c6-daaca1fcee3c",
   "metadata": {},
   "source": [
    "auto_encoder4.save(\"./auto_encoder_4_\")"
   ]
  },
  {
   "cell_type": "raw",
   "id": "c8e13854-74e7-48f6-8abf-751feeab6223",
   "metadata": {},
   "source": [
    "supervised_part = tf.keras.models.load_model('./supervised_part')"
   ]
  },
  {
   "cell_type": "raw",
   "id": "827ca00b-dbc6-408f-a271-47f8b0cfbf4e",
   "metadata": {},
   "source": [
    "supervised_part.fit(only_encoded4(dataset_without_label), dataset_labels[:, :10],\n",
    "            epochs=200, batch_size=100,\n",
    "            shuffle=True, validation_data=None, class_weight={i:class_weights[i] for i in range(10)},\n",
    "            callbacks=[TensorBoard(log_dir='/tmp/final_supervised')])"
   ]
  },
  {
   "cell_type": "raw",
   "id": "e0ac8efe-c4d1-4312-b4c0-64739e194d2b",
   "metadata": {},
   "source": [
    "supervised_part.save(\"./supervised_part\")"
   ]
  },
  {
   "cell_type": "raw",
   "id": "1db5a2f8-bf3c-4420-9166-f6e5994d009e",
   "metadata": {},
   "source": [
    "final_supervised = tf.keras.models.load_model('./final_supervised_')"
   ]
  },
  {
   "cell_type": "raw",
   "id": "d43664aa-66d6-4c50-9b84-631a2dc83d5c",
   "metadata": {},
   "source": [
    "class_weight = {i: class_weights[i] for i in range(10)}\n",
    "# class_weight = {i: 1 for i in range(10)}\n",
    "class_weight"
   ]
  },
  {
   "cell_type": "raw",
   "id": "f4d6720d-e0f4-48fe-af9a-45dc691bd0cf",
   "metadata": {},
   "source": [
    "without_last = np.concatenate((dataset_labels[:, :9], np.zeros((dataset_labels.shape[0], 1))), axis=1)\n",
    "without_last.shape"
   ]
  },
  {
   "cell_type": "raw",
   "id": "bdd81e42-5f40-4ce7-9cff-dc269bc69653",
   "metadata": {},
   "source": [
    "for i in range(100):\n",
    "    results = final_supervised.evaluate(dataset_without_label, dataset_labels[:, :5], batch_size=128)[1]\n",
    "    if (results >= 0.94):\n",
    "        break\n",
    "    final_supervised.fit(dataset_without_label, dataset_labels[:, :5],\n",
    "            epochs=1, batch_size=128,\n",
    "            shuffle=False, validation_data=None, class_weight={0: 0.9, 1: 0.9, 2: 1, 3: 0.1, 4: 0.9},\n",
    "            callbacks=[TensorBoard(log_dir='/tmp/final_supervised')])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "3f03578b-3d40-4f80-89df-52b99d3f3477",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2688/2688 [==============================] - 3s 846us/step - loss: 1.9273 - accuracy: 0.0000e+00 - categorical_crossentropy: 1.9273 - true_positives: 0.0000e+00 - false_negatives: 342096.0000 - true_negatives: 1377629.0000 - false_positives: 0.0000e+00 - weighted_accuracy: 0.0000e+00 - weighted_categorical_crossentropy: 1.9273\n",
      "2688/2688 [==============================] - 4s 1ms/step - loss: 0.2148 - accuracy: 0.9485 - categorical_crossentropy: 0.2148 - true_positives: 310373.0000 - false_negatives: 31723.0000 - true_negatives: 1371705.0000 - false_positives: 5924.0000 - weighted_accuracy: 0.9485 - weighted_categorical_crossentropy: 0.2148\n",
      "2688/2688 [==============================] - 2s 853us/step - loss: 0.1381 - accuracy: 0.9769 - categorical_crossentropy: 0.1381 - true_positives: 335986.0000 - false_negatives: 6110.0000 - true_negatives: 1369670.0000 - false_positives: 7959.0000 - weighted_accuracy: 0.9769 - weighted_categorical_crossentropy: 0.1381\n",
      "2688/2688 [==============================] - 3s 1ms/step - loss: 0.1320 - accuracy: 0.9769 - categorical_crossentropy: 0.1320 - true_positives: 335986.0000 - false_negatives: 6110.0000 - true_negatives: 1369670.0000 - false_positives: 7959.0000 - weighted_accuracy: 0.9769 - weighted_categorical_crossentropy: 0.1320\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[1;32m~\\AppData\\Local\\Temp\\ipykernel_18696\\1846597461.py\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m      1\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m100\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 2\u001b[1;33m     \u001b[0mresults\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mfinal_supervised\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mevaluate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdataset_without_label\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdataset_labels\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m:\u001b[0m\u001b[1;36m5\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m128\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m      3\u001b[0m     \u001b[1;32mif\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0mresults\u001b[0m \u001b[1;33m>=\u001b[0m \u001b[1;36m0.99\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      4\u001b[0m         \u001b[1;32mbreak\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      5\u001b[0m     final_supervised.fit(dataset_without_label, dataset_labels[:, :5],\n",
      "\u001b[1;32mE:\\anaconda3\\lib\\site-packages\\keras\\utils\\traceback_utils.py\u001b[0m in \u001b[0;36merror_handler\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m     63\u001b[0m         \u001b[0mfiltered_tb\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     64\u001b[0m         \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 65\u001b[1;33m             \u001b[1;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     66\u001b[0m         \u001b[1;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     67\u001b[0m             \u001b[0mfiltered_tb\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0m_process_traceback_frames\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0me\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m__traceback__\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mE:\\anaconda3\\lib\\site-packages\\keras\\engine\\training.py\u001b[0m in \u001b[0;36mevaluate\u001b[1;34m(self, x, y, batch_size, verbose, sample_weight, steps, callbacks, max_queue_size, workers, use_multiprocessing, return_dict, **kwargs)\u001b[0m\n\u001b[0;32m   1999\u001b[0m                 \u001b[1;31m# Creates a `tf.data.Dataset` and handles batch and epoch\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   2000\u001b[0m                 \u001b[1;31m# iteration.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 2001\u001b[1;33m                 data_handler = data_adapter.get_data_handler(\n\u001b[0m\u001b[0;32m   2002\u001b[0m                     \u001b[0mx\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   2003\u001b[0m                     \u001b[0my\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0my\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mE:\\anaconda3\\lib\\site-packages\\keras\\engine\\data_adapter.py\u001b[0m in \u001b[0;36mget_data_handler\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m   1577\u001b[0m     \u001b[1;32mif\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;34m\"model\"\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;34m\"_cluster_coordinator\"\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1578\u001b[0m         \u001b[1;32mreturn\u001b[0m \u001b[0m_ClusterCoordinatorDataHandler\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1579\u001b[1;33m     \u001b[1;32mreturn\u001b[0m \u001b[0mDataHandler\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1580\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1581\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mE:\\anaconda3\\lib\\site-packages\\keras\\engine\\data_adapter.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, x, y, sample_weight, batch_size, steps_per_epoch, initial_epoch, epochs, shuffle, class_weight, max_queue_size, workers, use_multiprocessing, model, steps_per_execution, distribute)\u001b[0m\n\u001b[0;32m   1257\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1258\u001b[0m         \u001b[0madapter_cls\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mselect_data_adapter\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0my\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1259\u001b[1;33m         self._adapter = adapter_cls(\n\u001b[0m\u001b[0;32m   1260\u001b[0m             \u001b[0mx\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1261\u001b[0m             \u001b[0my\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mE:\\anaconda3\\lib\\site-packages\\keras\\engine\\data_adapter.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, x, y, sample_weights, sample_weight_modes, batch_size, epochs, steps, shuffle, **kwargs)\u001b[0m\n\u001b[0;32m    243\u001b[0m     ):\n\u001b[0;32m    244\u001b[0m         \u001b[0msuper\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0my\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 245\u001b[1;33m         \u001b[0mx\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0my\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0msample_weights\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0m_process_tensorlike\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0my\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0msample_weights\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    246\u001b[0m         sample_weight_modes = broadcast_sample_weight_modes(\n\u001b[0;32m    247\u001b[0m             \u001b[0msample_weights\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0msample_weight_modes\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mE:\\anaconda3\\lib\\site-packages\\keras\\engine\\data_adapter.py\u001b[0m in \u001b[0;36m_process_tensorlike\u001b[1;34m(inputs)\u001b[0m\n\u001b[0;32m   1137\u001b[0m         \u001b[1;32mreturn\u001b[0m \u001b[0mx\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1138\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1139\u001b[1;33m     \u001b[0minputs\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtf\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mnest\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmap_structure\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0m_convert_single_tensor\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1140\u001b[0m     \u001b[1;32mreturn\u001b[0m \u001b[0mtf\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m__internal__\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mnest\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlist_to_tuple\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1141\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mE:\\anaconda3\\lib\\site-packages\\tensorflow\\python\\util\\nest.py\u001b[0m in \u001b[0;36mmap_structure\u001b[1;34m(func, *structure, **kwargs)\u001b[0m\n\u001b[0;32m    915\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    916\u001b[0m   return pack_sequence_as(\n\u001b[1;32m--> 917\u001b[1;33m       \u001b[0mstructure\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[0mfunc\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mx\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mentries\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    918\u001b[0m       expand_composites=expand_composites)\n\u001b[0;32m    919\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mE:\\anaconda3\\lib\\site-packages\\tensorflow\\python\\util\\nest.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[1;34m(.0)\u001b[0m\n\u001b[0;32m    915\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    916\u001b[0m   return pack_sequence_as(\n\u001b[1;32m--> 917\u001b[1;33m       \u001b[0mstructure\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[0mfunc\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mx\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mentries\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    918\u001b[0m       expand_composites=expand_composites)\n\u001b[0;32m    919\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mE:\\anaconda3\\lib\\site-packages\\keras\\engine\\data_adapter.py\u001b[0m in \u001b[0;36m_convert_single_tensor\u001b[1;34m(x)\u001b[0m\n\u001b[0;32m   1132\u001b[0m             \u001b[1;32mif\u001b[0m \u001b[0missubclass\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdtype\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtype\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfloating\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1133\u001b[0m                 \u001b[0mdtype\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mbackend\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfloatx\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1134\u001b[1;33m             \u001b[1;32mreturn\u001b[0m \u001b[0mtf\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mconvert_to_tensor\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mdtype\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1135\u001b[0m         \u001b[1;32melif\u001b[0m \u001b[0m_is_scipy_sparse\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1136\u001b[0m             \u001b[1;32mreturn\u001b[0m \u001b[0m_scipy_sparse_to_sparse_tensor\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mE:\\anaconda3\\lib\\site-packages\\tensorflow\\python\\util\\traceback_utils.py\u001b[0m in \u001b[0;36merror_handler\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m    148\u001b[0m     \u001b[0mfiltered_tb\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    149\u001b[0m     \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 150\u001b[1;33m       \u001b[1;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    151\u001b[0m     \u001b[1;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    152\u001b[0m       \u001b[0mfiltered_tb\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0m_process_traceback_frames\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0me\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m__traceback__\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mE:\\anaconda3\\lib\\site-packages\\tensorflow\\python\\util\\dispatch.py\u001b[0m in \u001b[0;36mop_dispatch_handler\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m   1174\u001b[0m       \u001b[1;31m# Fallback dispatch system (dispatch v1):\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1175\u001b[0m       \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1176\u001b[1;33m         \u001b[1;32mreturn\u001b[0m \u001b[0mdispatch_target\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1177\u001b[0m       \u001b[1;32mexcept\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0mTypeError\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mValueError\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1178\u001b[0m         \u001b[1;31m# Note: convert_to_eager_tensor currently raises a ValueError, not a\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mE:\\anaconda3\\lib\\site-packages\\tensorflow\\python\\framework\\ops.py\u001b[0m in \u001b[0;36mconvert_to_tensor_v2_with_dispatch\u001b[1;34m(value, dtype, dtype_hint, name)\u001b[0m\n\u001b[0;32m   1488\u001b[0m     \u001b[0mValueError\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mIf\u001b[0m \u001b[0mthe\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0mvalue\u001b[0m\u001b[0;31m`\u001b[0m \u001b[1;32mis\u001b[0m \u001b[0ma\u001b[0m \u001b[0mtensor\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0mof\u001b[0m \u001b[0mgiven\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;31m`\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mgraph\u001b[0m \u001b[0mmode\u001b[0m\u001b[1;33m.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1489\u001b[0m   \"\"\"\n\u001b[1;32m-> 1490\u001b[1;33m   return convert_to_tensor_v2(\n\u001b[0m\u001b[0;32m   1491\u001b[0m       value, dtype=dtype, dtype_hint=dtype_hint, name=name)\n\u001b[0;32m   1492\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mE:\\anaconda3\\lib\\site-packages\\tensorflow\\python\\framework\\ops.py\u001b[0m in \u001b[0;36mconvert_to_tensor_v2\u001b[1;34m(value, dtype, dtype_hint, name)\u001b[0m\n\u001b[0;32m   1494\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mconvert_to_tensor_v2\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mvalue\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mNone\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdtype_hint\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mNone\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mname\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mNone\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1495\u001b[0m   \u001b[1;34m\"\"\"Converts the given `value` to a `Tensor`.\"\"\"\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1496\u001b[1;33m   return convert_to_tensor(\n\u001b[0m\u001b[0;32m   1497\u001b[0m       \u001b[0mvalue\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mvalue\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1498\u001b[0m       \u001b[0mdtype\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mdtype\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mE:\\anaconda3\\lib\\site-packages\\tensorflow\\python\\profiler\\trace.py\u001b[0m in \u001b[0;36mwrapped\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m    181\u001b[0m         \u001b[1;32mwith\u001b[0m \u001b[0mTrace\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtrace_name\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mtrace_kwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    182\u001b[0m           \u001b[1;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 183\u001b[1;33m       \u001b[1;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    184\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    185\u001b[0m     \u001b[1;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mE:\\anaconda3\\lib\\site-packages\\tensorflow\\python\\framework\\ops.py\u001b[0m in \u001b[0;36mconvert_to_tensor\u001b[1;34m(value, dtype, name, as_ref, preferred_dtype, dtype_hint, ctx, accepted_result_types)\u001b[0m\n\u001b[0;32m   1634\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1635\u001b[0m     \u001b[1;32mif\u001b[0m \u001b[0mret\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1636\u001b[1;33m       \u001b[0mret\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mconversion_func\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mvalue\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mdtype\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mname\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mname\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mas_ref\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mas_ref\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1637\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1638\u001b[0m     \u001b[1;32mif\u001b[0m \u001b[0mret\u001b[0m \u001b[1;32mis\u001b[0m \u001b[0mNotImplemented\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mE:\\anaconda3\\lib\\site-packages\\tensorflow\\python\\framework\\tensor_conversion_registry.py\u001b[0m in \u001b[0;36m_default_conversion_function\u001b[1;34m(***failed resolving arguments***)\u001b[0m\n\u001b[0;32m     46\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_default_conversion_function\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mvalue\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mname\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mas_ref\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     47\u001b[0m   \u001b[1;32mdel\u001b[0m \u001b[0mas_ref\u001b[0m  \u001b[1;31m# Unused.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 48\u001b[1;33m   \u001b[1;32mreturn\u001b[0m \u001b[0mconstant_op\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mconstant\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mvalue\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mname\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mname\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     49\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     50\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mE:\\anaconda3\\lib\\site-packages\\tensorflow\\python\\framework\\constant_op.py\u001b[0m in \u001b[0;36mconstant\u001b[1;34m(value, dtype, shape, name)\u001b[0m\n\u001b[0;32m    265\u001b[0m     \u001b[0mValueError\u001b[0m\u001b[1;33m:\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mcalled\u001b[0m \u001b[0mon\u001b[0m \u001b[0ma\u001b[0m \u001b[0msymbolic\u001b[0m \u001b[0mtensor\u001b[0m\u001b[1;33m.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    266\u001b[0m   \"\"\"\n\u001b[1;32m--> 267\u001b[1;33m   return _constant_impl(value, dtype, shape, name, verify_shape=False,\n\u001b[0m\u001b[0;32m    268\u001b[0m                         allow_broadcast=True)\n\u001b[0;32m    269\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mE:\\anaconda3\\lib\\site-packages\\tensorflow\\python\\framework\\constant_op.py\u001b[0m in \u001b[0;36m_constant_impl\u001b[1;34m(value, dtype, shape, name, verify_shape, allow_broadcast)\u001b[0m\n\u001b[0;32m    277\u001b[0m       \u001b[1;32mwith\u001b[0m \u001b[0mtrace\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mTrace\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"tf.constant\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    278\u001b[0m         \u001b[1;32mreturn\u001b[0m \u001b[0m_constant_eager_impl\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mctx\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mshape\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mverify_shape\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 279\u001b[1;33m     \u001b[1;32mreturn\u001b[0m \u001b[0m_constant_eager_impl\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mctx\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mshape\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mverify_shape\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    280\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    281\u001b[0m   \u001b[0mg\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mops\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mget_default_graph\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mE:\\anaconda3\\lib\\site-packages\\tensorflow\\python\\framework\\constant_op.py\u001b[0m in \u001b[0;36m_constant_eager_impl\u001b[1;34m(ctx, value, dtype, shape, verify_shape)\u001b[0m\n\u001b[0;32m    302\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_constant_eager_impl\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mctx\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mshape\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mverify_shape\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    303\u001b[0m   \u001b[1;34m\"\"\"Creates a constant on the current device.\"\"\"\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 304\u001b[1;33m   \u001b[0mt\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mconvert_to_eager_tensor\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mvalue\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mctx\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    305\u001b[0m   \u001b[1;32mif\u001b[0m \u001b[0mshape\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    306\u001b[0m     \u001b[1;32mreturn\u001b[0m \u001b[0mt\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mE:\\anaconda3\\lib\\site-packages\\tensorflow\\python\\framework\\constant_op.py\u001b[0m in \u001b[0;36mconvert_to_eager_tensor\u001b[1;34m(value, ctx, dtype)\u001b[0m\n\u001b[0;32m    100\u001b[0m       \u001b[0mdtype\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mdtypes\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mas_dtype\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdtype\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mas_datatype_enum\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    101\u001b[0m   \u001b[0mctx\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mensure_initialized\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 102\u001b[1;33m   \u001b[1;32mreturn\u001b[0m \u001b[0mops\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mEagerTensor\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mvalue\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mctx\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdevice_name\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    103\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    104\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Alternate evaluation and single-epoch training for at most 100 rounds,\n",
    "# stopping early once the monitored value on the training data reaches 0.99.\n",
    "for training_round in range(100):\n",
    "    # evaluate() returns a sequence here; entry 1 is the value compared to the\n",
    "    # threshold (presumably accuracy - confirm against the compile() metrics).\n",
    "    results = final_supervised.evaluate(\n",
    "        dataset_without_label,\n",
    "        dataset_labels[:, :5],\n",
    "        batch_size=128,\n",
    "    )[1]\n",
    "    if results >= 0.99:\n",
    "        break\n",
    "    # A fresh TensorBoard callback is created for every round, as before.\n",
    "    tensorboard_cb = TensorBoard(log_dir='/tmp/final_supervised')\n",
    "    final_supervised.fit(\n",
    "        dataset_without_label,\n",
    "        dataset_labels[:, :5],\n",
    "        epochs=1,\n",
    "        batch_size=128,\n",
    "        shuffle=False,\n",
    "        validation_data=None,\n",
    "        callbacks=[tensorboard_cb],\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5993b5d4-b6dc-4789-bca4-3eb350904c85",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the trained model to disk, relative to the working directory.\n",
    "final_supervised.save(filepath=\"./final_supervised_\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fedfa4cf-e66d-42f0-b3ad-e41ae7728aad",
   "metadata": {},
   "source": [
    "# Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7bc39114-ad57-4b2d-b8ec-88df54b534f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the test array: the first 117 columns are the model inputs,\n",
    "# the remaining columns are the labels.\n",
    "test_dataset_without_labels = test_dataset[:, :117]\n",
    "test_dataset_labels = test_dataset[:, 117:]\n",
    "print(\"Evaluate on test data\")\n",
    "# Only the first five label columns are scored, as in the training cell.\n",
    "# batch_size is 128 like every other fit()/evaluate() call in this notebook;\n",
    "# the previous value of 12 only made evaluation take many more steps.\n",
    "results = final_supervised.evaluate(\n",
    "    test_dataset_without_labels,\n",
    "    test_dataset_labels[:, :5],\n",
    "    batch_size=128,\n",
    ")\n",
    "print(\"test loss, test acc:\", results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "741ba3c6-bfa6-4a3a-bc10-564e9fd1e967",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the model on the data it was trained on, for comparison with the test score.\n",
    "print(\"Evaluate on train data\")\n",
    "results = final_supervised.evaluate(\n",
    "    x=dataset_without_label,\n",
    "    y=dataset_labels[:, :5],\n",
    "    batch_size=128,\n",
    ")\n",
    "print(results)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fff94709-1fbc-417d-a6da-bc2f32d84877",
   "metadata": {},
   "source": [
    "# Confusion Matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "465748d1-0964-4b7d-b71f-8dd18f81a1e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predict in mini-batches: calling the model directly would run the whole\n",
    "# training array through a single forward pass, which needs far more memory.\n",
    "y_hat_tr = final_supervised.predict(dataset_without_label, batch_size=128, verbose=0)\n",
    "# Index of the highest-scoring output for every row.\n",
    "y_hat_tr_idx = np.asarray(y_hat_tr).argmax(axis=1)\n",
    "# NOTE(review): fit()/evaluate() use dataset_labels[:, :5], but this argmax runs\n",
    "# over every label column - confirm the true indices are meant to cover all of them.\n",
    "y_tr_idx = dataset_labels.argmax(axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1cccaeef-f0d8-465d-a2d6-769465776481",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Confusion matrix: rows are true class indices, columns are predicted class indices.\n",
    "conf = np.zeros((10, 10), dtype=np.int32)\n",
    "print(y_tr_idx.shape[0])\n",
    "# `conf[rows, cols] += 1` adds 1 only once per distinct (row, col) pair, so repeating\n",
    "# it in a loop never produced counts. np.add.at performs an unbuffered add that\n",
    "# accumulates every repeated pair, giving one increment per sample in a single call.\n",
    "np.add.at(conf, (y_tr_idx, y_hat_tr_idx), 1)\n",
    "# Every sample must land in exactly one cell of the matrix.\n",
    "assert conf.sum() == y_tr_idx.shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9aaa00e-dedc-4958-bfe3-4d1f10da5163",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Draw on an explicitly created 10x10-inch figure. plt.matshow() opens a new figure\n",
    "# by default, which left the figure made by plt.figure(figsize=...) empty.\n",
    "# matplotlib.pyplot is already imported as plt in the first cell.\n",
    "fig, ax = plt.subplots(figsize=(10, 10))\n",
    "im = ax.matshow(conf)\n",
    "ax.set_title('Confusion matrix (training data)')\n",
    "ax.set_xlabel('Predicted class index')\n",
    "ax.set_ylabel('True class index')\n",
    "fig.colorbar(im, ax=ax)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31f01d08-69de-46e4-a8c6-919dbbd39fef",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d97a8813-1447-450c-b190-5d85a6a42da7",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
